{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "\n",
    "import numpy as np\n",
    "import matplotlib\n",
    "from matplotlib import pyplot as plt\n",
    "from matplotlib import cm\n",
    "%matplotlib inline\n",
    "%config InlineBackend.figure_format = 'png'\n",
    "matplotlib.rcParams['figure.figsize'] = (12.0, 4.0)\n",
    "matplotlib.rcParams['font.size'] = 7\n",
    "\n",
    "import matplotlib.lines as mlines\n",
    "import seaborn\n",
    "seaborn.set_style('darkgrid')\n",
    "import logging\n",
    "import importlib\n",
    "importlib.reload(logging) # see https://stackoverflow.com/a/21475297/1469195\n",
    "log = logging.getLogger()\n",
    "log.setLevel('DEBUG')\n",
    "import sys\n",
    "logging.basicConfig(format='%(asctime)s %(levelname)s : %(message)s',\n",
    "                     level=logging.DEBUG, stream=sys.stdout)\n",
    "seaborn.set_palette('colorblind')\n",
    "import os\n",
    "# add the repo itself\n",
    "os.sys.path.insert(0, '/home/schirrmr/code/explaining/reversible//')\n",
    "os.sys.path.insert(0, '/home/schirrmr/braindecode/code/braindecode/')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from braindecode.datasets.bbci import BBCIDataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "cnt = BBCIDataset('/data/schirrmr/schirrmr/HGD-public/reduced/train/8.mat', load_sensor_names=['C3', 'C4', 'Cz']).load()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In trialwise decoding, you forward trials through your network and get one prediction per trial. Then you train the network using the loss between the predictions and the labels for the trials.\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(5,5))\n",
    "data = cnt.get_data()[0,100:201]\n",
    "plt.plot(data - np.mean(data))\n",
    "ax = plt.gca()\n",
    "plt.plot([0,0],[-20,20], color='black')\n",
    "plt.plot([100,100],[-20,20], color='black')\n",
    "points = [[0, -20], [50, -80], [100, -20]]\n",
    "polygon = plt.Polygon(points,  facecolor=seaborn.color_palette()[2] + (0.2,), edgecolor=seaborn.color_palette()[2])\n",
    "ax.add_artist(polygon)\n",
    "plt.ylim(-100,22)\n",
    "plt.text(50,30, \"Trial\", fontsize=20, ha='center')\n",
    "plt.text(50,-40, \"Network\", fontsize=20, ha='center')\n",
    "plt.plot(50,-80, ls='', marker='o', color=seaborn.color_palette()[2])\n",
    "plt.text(45,-80, \"Prediction\", fontsize=20, ha='right', va='center')\n",
    "plt.plot(50,-95, ls='', marker='o')\n",
    "plt.text(45,-95, \"Target\", fontsize=20, ha='right', va='center')\n",
    "plt.axis('off')\n",
    "\n",
    "plt.plot([50,65],[-80, -87.5], color='black')\n",
    "plt.plot([50,65],[-95, -87.5], color='black')\n",
    "plt.text(67,-87.5, \"Loss\", fontsize=20, ha='left', va='center')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(5,5))\n",
    "data = cnt.get_data()[0,100:201]\n",
    "plt.plot(data - np.mean(data))\n",
    "ax = plt.gca()\n",
    "plt.plot([0,0],[-20,20], color='black')\n",
    "plt.plot([100,100],[-20,20], color='black')\n",
    "points = [[0, -20], [25, -80], [50, -20]]\n",
    "polygon = plt.Polygon(points,  facecolor=seaborn.color_palette()[2] + (0.2,), edgecolor=seaborn.color_palette()[2])\n",
    "ax.add_artist(polygon)\n",
    "points = [[3, -20], [28, -80], [53, -20]]\n",
    "polygon = plt.Polygon(points,  facecolor=seaborn.color_palette()[2] + (0.2,), edgecolor=seaborn.color_palette()[2])\n",
    "ax.add_artist(polygon)\n",
    "points = [[6, -20], [31, -80], [56, -20]]\n",
    "polygon = plt.Polygon(points,  facecolor=seaborn.color_palette()[2] + (0.2,), edgecolor=seaborn.color_palette()[2])\n",
    "ax.add_artist(polygon)\n",
    "plt.plot(25,-80, ls='', marker='o', color=seaborn.color_palette()[2])\n",
    "plt.plot(28,-80, ls='', marker='o', color=seaborn.color_palette()[2])\n",
    "plt.plot(31,-80, ls='', marker='o', color=seaborn.color_palette()[2])\n",
    "plt.ylim(-100,22)\n",
    "plt.text(50,35, \"Trial\", fontsize=20, ha='center')\n",
    "plt.text(28,-38, \"Network\", fontsize=16, ha='center')\n",
    "plt.text(20,-80, \"Predictions\", fontsize=20, ha='right', va='center')\n",
    "plt.plot(28,-95, ls='', marker='o')\n",
    "plt.text(20,-95, \"Target\", fontsize=20, ha='right', va='center')\n",
    "\n",
    "# from https://stackoverflow.com/a/20308475/1469195\n",
    "def range_brace(x_min, x_max, mid=0.75, \n",
    "                beta1=50.0, beta2=100.0, height=1, \n",
    "                initial_divisions=11, resolution_factor=1.5):\n",
    "    NP = np\n",
    "    # determine x0 adaptively values using second derivitive\n",
    "    # could be replaced with less snazzy:\n",
    "    #   x0 = NP.arange(0, 0.5, .001)\n",
    "    x0 = NP.array(())\n",
    "    tmpx = NP.linspace(0, 0.5, initial_divisions)\n",
    "    tmp = beta1**2 * (NP.exp(beta1*tmpx)) * (1-NP.exp(beta1*tmpx)) / NP.power((1+NP.exp(beta1*tmpx)),3)\n",
    "    tmp += beta2**2 * (NP.exp(beta2*(tmpx-0.5))) * (1-NP.exp(beta2*(tmpx-0.5))) / NP.power((1+NP.exp(beta2*(tmpx-0.5))),3)\n",
    "    for i in range(0, len(tmpx)-1):\n",
    "        t = int(NP.ceil(resolution_factor*max(NP.abs(tmp[i:i+2]))/float(initial_divisions)))\n",
    "        x0 = NP.append(x0, NP.linspace(tmpx[i],tmpx[i+1],t))\n",
    "    x0 = NP.sort(NP.unique(x0)) # sort and remove dups\n",
    "    # half brace using sum of two logistic functions\n",
    "    y0 = mid*2*((1/(1.+NP.exp(-1*beta1*x0)))-0.5)\n",
    "    y0 += (1-mid)*2*(1/(1.+NP.exp(-1*beta2*(x0-0.5))))\n",
    "    # concat and scale x\n",
    "    x = NP.concatenate((x0, 1-x0[::-1])) * float((x_max-x_min)) + x_min\n",
    "    y = NP.concatenate((y0, y0[::-1])) * float(height)\n",
    "    return (x,y)\n",
    "\n",
    "\n",
    "x,y = range_brace(0, 50, height=6)\n",
    "ax.plot(x, y-20,'-')\n",
    "\n",
    "ax.annotate('Crop',  xy=(25, -10), xycoords='data', \n",
    "            fontsize=13, ha='center', va='bottom', color=seaborn.color_palette()[2],\n",
    "            bbox=dict(boxstyle='square', fc='white', ec='None',))\n",
    "\n",
    "x,y = range_brace(0, 56, height=6, mid=0.5, beta1=50)\n",
    "ax.plot(x, y+10,'-', color=seaborn.color_palette()[2])\n",
    "plt.plot([56,56], [-20,10], ls='--', color=seaborn.color_palette()[2])\n",
    "ax.annotate('Supercrop',  xy=(28, 20), xycoords='data', \n",
    "            fontsize=13, ha='center', va='bottom', color=seaborn.color_palette()[2],\n",
    "            bbox=dict(boxstyle='square', fc='white', ec='None',))\n",
    "x,y = range_brace(25, 31, height=3.5, mid=0.5, beta1=50)\n",
    "ax.plot(x, -y-82,'-', color=seaborn.color_palette()[2])\n",
    "plt.plot(28,-88, ls='', marker='o',  color=seaborn.color_palette()[2])\n",
    "\n",
    "plt.plot([28,40],[-88, -91.5], color='black')\n",
    "plt.plot([28,40],[-95, -91.5], color='black')\n",
    "plt.text(42,-91.5, \"Loss\", fontsize=20, ha='left', va='center')\n",
    "plt.axis('off')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "\n",
    "import matplotlib.transforms as mtrans\n",
    "from matplotlib.text import TextPath\n",
    "from matplotlib.patches import PathPatch\n",
    "def curly(x,y, size, scale, rotation, ax=None):\n",
    "    tp = TextPath((0, 0), \"}\", size=size,)\n",
    "    #trans = mtrans.Affine2D().scale(1, scale) + mtrans.Affine2D().rotate_deg(rotation) + \\\n",
    "    #    mtrans.Affine2D().translate(x,y) + ax.transData\n",
    "    trans = mtrans.Affine2D().rotate_deg(rotation) + mtrans.Affine2D().translate(x,y) + ax.transData\n",
    "    pp = PathPatch(tp, lw=0, fc=\"k\", transform=trans)\n",
    "    ax.add_artist(pp)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(5,5))\n",
    "data = cnt.get_data()[0,100:201]\n",
    "plt.plot(data - np.mean(data))\n",
    "ax = plt.gca()\n",
    "plt.plot([0,0],[-20,20], color='black')\n",
    "plt.plot([100,100],[-20,20], color='black')\n",
    "points = [[0, -20], [25, -80], [50, -20]]\n",
    "polygon = plt.Polygon(points,  facecolor=seaborn.color_palette()[2] + (0.2,), edgecolor=seaborn.color_palette()[2])\n",
    "ax.add_artist(polygon)\n",
    "points = [[3, -20], [28, -80], [53, -20]]\n",
    "polygon = plt.Polygon(points,  facecolor=seaborn.color_palette()[2] + (0.2,), edgecolor=seaborn.color_palette()[2])\n",
    "ax.add_artist(polygon)\n",
    "points = [[6, -20], [31, -80], [56, -20]]\n",
    "polygon = plt.Polygon(points,  facecolor=seaborn.color_palette()[2] + (0.2,), edgecolor=seaborn.color_palette()[2])\n",
    "ax.add_artist(polygon)\n",
    "plt.plot(25,-80, ls='', marker='o', color=seaborn.color_palette()[2])\n",
    "plt.plot(28,-80, ls='', marker='o', color=seaborn.color_palette()[2])\n",
    "plt.plot(31,-80, ls='', marker='o', color=seaborn.color_palette()[2])\n",
    "plt.ylim(-100,22)\n",
    "plt.text(50,30, \"Trial\", fontsize=20, ha='center')\n",
    "plt.text(20,-80, \"Predictions\", fontsize=20, ha='right', va='center')\n",
    "plt.plot(25,-95, ls='', marker='o')\n",
    "plt.text(20,-95, \"Target\", fontsize=20, ha='right', va='center')\n",
    "\n",
    "\n",
    "# Here is the label and arrow code of interest\n",
    "\n",
    "ax.annotate('SDL', xy=(0.5, 0.90), xytext=(0.5, 1.00), xycoords='data', \n",
    "            fontsize=13, ha='center', va='bottom',\n",
    "            bbox=dict(boxstyle='square', fc='white'),\n",
    "            arrowprops=dict(arrowstyle='-[, widthB=7.0, lengthB=1.5', lw=2.0))\n",
    "\"\"\"\n",
    "plt.annotate(r\"$\\}$\",fontsize=46,\n",
    "            xy=(0.27, 0.77), xycoords='figure fraction',\n",
    "            rotation=90)\"\"\"\n",
    "\n",
    "import matplotlib.transforms as mtrans\n",
    "from matplotlib.text import TextPath\n",
    "from matplotlib.patches import PathPatch\n",
    "def curly(x,y, size, scale, rotation, ax=None):\n",
    "    tp = TextPath((0, 0), \"}\", size=size)\n",
    "    #trans = mtrans.Affine2D().scale(1, scale) + mtrans.Affine2D().rotate_deg(rotation) + \\\n",
    "    #    mtrans.Affine2D().translate(x,y) + ax.transData\n",
    "    trans = ax.transData + mtrans.Affine2D().translate(x,y) \n",
    "    pp = PathPatch(tp, lw=0, fc=\"k\", transform=trans)\n",
    "    ax.add_artist(pp)\n",
    "\n",
    "\n",
    "curly(10,-20,20, 0.5, 90, ax)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "plt.text(45,-80, \"Prediction\", fontsize=20, ha='right', va='center')\n",
    "plt.plot(50,-95, ls='', marker='o')\n",
    "plt.text(45,-95, \"Target\", fontsize=20, ha='right', va='center')\n",
    "plt.axis('off')\n",
    "\n",
    "plt.plot([50,65],[-80, -87.5], color='black')\n",
    "plt.plot([50,65],[-95, -87.5], color='black')\n",
    "plt.text(67,-87.5, \"Loss\", fontsize=20, ha='left', va='center')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "plt.plot(cnt.get_data()[0,100:201])\n",
    "plt.axvline(x=0, color='black')\n",
    "plt.axvline(x=100, color='black')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
